{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inference with your model\n",
    "\n",
    "This is the third and final tutorial of our [beginner tutorial series](https://github.com/awslabs/djl/tree/master/jupyter/tutorial) that will take you through creating, training, and running inference on a neural network. In this tutorial, you will learn how to execute your image classification model for a production system.\n",
    "\n",
    "In the [previous tutorial](02_train_your_first_model.ipynb), you successfully trained your model. Now, you will learn how to implement a `Translator` that converts between plain Java objects (POJOs) and `NDArray`s, and how to use a `Predictor` to run inference.\n",
    "\n",
    "\n",
    "## Preparation\n",
    "\n",
    "This tutorial requires the installation of the Java Jupyter Kernel. To install the kernel, see the [Jupyter README](https://github.com/awslabs/djl/blob/master/jupyter/README.md)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Add the snapshot repository to get the DJL snapshot artifacts\n",
    "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n",
    "\n",
    "// Add the maven dependencies\n",
    "%maven ai.djl:api:0.8.0\n",
    "%maven ai.djl:model-zoo:0.8.0\n",
    "%maven ai.djl.mxnet:mxnet-engine:0.8.0\n",
    "%maven ai.djl.mxnet:mxnet-model-zoo:0.8.0\n",
    "%maven org.slf4j:slf4j-api:1.7.26\n",
    "%maven org.slf4j:slf4j-simple:1.7.26\n",
    "%maven net.java.dev.jna:jna:5.3.0\n",
    "    \n",
    "// See https://github.com/awslabs/djl/blob/master/mxnet/mxnet-engine/README.md\n",
    "// for more MXNet library selection options\n",
    "%maven ai.djl.mxnet:mxnet-native-auto:1.7.0-backport"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import java.awt.image.*;\n",
    "import java.nio.file.*;\n",
    "import java.util.*;\n",
    "import java.util.stream.*;\n",
    "import ai.djl.*;\n",
    "import ai.djl.basicmodelzoo.basic.*;\n",
    "import ai.djl.ndarray.*;\n",
    "import ai.djl.modality.*;\n",
    "import ai.djl.modality.cv.*;\n",
    "import ai.djl.modality.cv.util.NDImageUtils;\n",
    "import ai.djl.translate.*;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1: Load your handwritten digit image\n",
    "\n",
    "We will start by loading the image that we want our model to classify."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "var img = ImageFactory.getInstance().fromUrl(\"https://djl-ai.s3.amazonaws.com/resources/images/0.png\");\n",
    "img.getWrappedImage();"
   ]
  },
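  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: Load your model\n",
    "\n",
    "Next, we need to load the model that we will run inference with. This model should have been saved to the `build/mlp` directory when you ran the [previous tutorial](02_train_your_first_model.ipynb). The block passed to `setBlock` must describe the same architecture that was trained, because the saved files hold the parameters, not the Java definition of the network."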
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Path modelDir = Paths.get(\"build/mlp\");\n",
    "Model model = Model.newInstance(\"mlp\");\n",
    "model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));\n",
    "model.load(modelDir);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In addition to loading a local model, you can also find pretrained models within our [model zoo](http://docs.djl.ai/docs/model-zoo.html). See more options in our [model loading documentation](http://docs.djl.ai/docs/load_model.html).\n",
    "\n",
    "## Step 3: Create a `Translator`\n",
    "\n",
    "The [`Translator`](https://javadoc.io/static/ai.djl/api/0.8.0/index.html?ai/djl/translate/Translator.html) is used to encapsulate the pre-processing and post-processing functionality of your application. Both `processInput` and `processOutput` operate on single data items, not batches."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Translator<Image, Classifications> translator = new Translator<Image, Classifications>() {\n",
    "\n",
    "    @Override\n",
    "    public NDList processInput(TranslatorContext ctx, Image input) {\n",
    "        // Convert Image to NDArray\n",
    "        NDArray array = input.toNDArray(ctx.getNDManager(), Image.Flag.GRAYSCALE);\n",
    "        return new NDList(NDImageUtils.toTensor(array));\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    public Classifications processOutput(TranslatorContext ctx, NDList list) {\n",
    "        // Create a Classifications with the output probabilities\n",
    "        NDArray probabilities = list.singletonOrThrow().softmax(0);\n",
    "        List<String> classNames = IntStream.range(0, 10).mapToObj(String::valueOf).collect(Collectors.toList());\n",
    "        return new Classifications(classNames, probabilities);\n",
    "    }\n",
    "    \n",
    "    @Override\n",
    "    public Batchifier getBatchifier() {\n",
    "        // The Batchifier describes how to combine a batch together\n",
    "        // Stacking, the most common batchifier, takes N [X1, X2, ...] arrays to a single [N, X1, X2, ...] array\n",
    "        return Batchifier.STACK;\n",
    "    }\n",
    "};"
   ]
  },
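  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `softmax(0)` call in `processOutput` turns the 10 raw output scores of the model into probabilities that sum to 1. As a rough illustration of that arithmetic, the following plain-Java sketch does not call DJL at all. The class name `SoftmaxSketch` and the sample scores are made up for this example, and DJL's own implementation may differ in detail."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Plain-Java sketch of the softmax step in processOutput (no DJL calls)\n",
    "import java.util.Arrays;\n",
    "\n",
    "public class SoftmaxSketch {\n",
    "\n",
    "    /** Converts raw scores (logits) into probabilities that sum to 1. */\n",
    "    public static double[] softmax(double[] logits) {\n",
    "        // Subtract the largest score first so Math.exp cannot overflow\n",
    "        double max = Arrays.stream(logits).max().orElse(0.0);\n",
    "        double[] exps = Arrays.stream(logits).map(v -> Math.exp(v - max)).toArray();\n",
    "        double sum = Arrays.stream(exps).sum();\n",
    "        return Arrays.stream(exps).map(v -> v / sum).toArray();\n",
    "    }\n",
    "\n",
    "    public static void main(String[] args) {\n",
    "        // Hypothetical scores for three classes; the largest score gets the largest probability\n",
    "        double[] probabilities = softmax(new double[] {2.0, 1.0, 0.1});\n",
    "        System.out.println(Arrays.toString(probabilities));\n",
    "    }\n",
    "}"
   ]
  },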
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4: Create a `Predictor`\n",
    "\n",
    "Using the translator, we will create a new [`Predictor`](https://javadoc.io/static/ai.djl/api/0.8.0/index.html?ai/djl/inference/Predictor.html). The predictor is the main class that orchestrates the inference process. During inference, a trained model is used to predict values, often for production use cases. The predictor is NOT thread-safe, so if you want to run predictions in parallel, call `newPredictor` once per thread so that each thread has its own predictor object."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "var predictor = model.newPredictor(translator);"
   ]
  },
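  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to follow the one-predictor-per-thread advice is sketched below. It assumes the `model`, `translator`, and `img` variables from the cells above, and the two-thread setup and the `worker-` thread names are only illustrative. Each thread creates its own predictor and closes it when it is done."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ai.djl.inference.Predictor;\n",
    "\n",
    "// Sketch: each worker thread creates, uses, and closes its own Predictor.\n",
    "// Assumes `model`, `translator`, and `img` from the cells above.\n",
    "Runnable worker = () -> {\n",
    "    try (Predictor<Image, Classifications> threadPredictor = model.newPredictor(translator)) {\n",
    "        Classifications result = threadPredictor.predict(img);\n",
    "        System.out.println(Thread.currentThread().getName() + \" -> \" + result.best());\n",
    "    } catch (TranslateException e) {\n",
    "        throw new IllegalStateException(e);\n",
    "    }\n",
    "};\n",
    "\n",
    "List<Thread> threads = new ArrayList<>();\n",
    "for (int i = 0; i < 2; i++) {\n",
    "    Thread t = new Thread(worker, \"worker-\" + i);\n",
    "    threads.add(t);\n",
    "    t.start();\n",
    "}\n",
    "for (Thread t : threads) {\n",
    "    t.join(); // wait for both predictions to finish\n",
    "}"
   ]
  },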
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 5: Run inference\n",
    "\n",
    "With our predictor, we can simply call the [`predict`](https://javadoc.io/static/ai.djl/api/0.8.0/ai/djl/inference/Predictor.html#predict-I-) method to run inference. For better performance, you can also call [`batchPredict`](https://javadoc.io/static/ai.djl/api/0.8.0/ai/djl/inference/Predictor.html#batchPredict-java.util.List-) with a list of input items. Afterwards, you should reuse the same predictor for further inference calls."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "var classifications = predictor.predict(img);\n",
    "\n",
    "classifications"
   ]
  },
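  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small sketch of the batch variant, the cell below passes a list of images to `batchPredict` and prints the most likely digit for each result. It assumes the `predictor` and `img` variables from the cells above, and it reuses the same image twice only because this tutorial loads a single image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Sketch: classify several images in one call; assumes `predictor` and `img` from the cells above\n",
    "List<Image> batch = Arrays.asList(img, img); // the same digit twice, purely for illustration\n",
    "List<Classifications> results = predictor.batchPredict(batch);\n",
    "\n",
    "for (Classifications result : results) {\n",
    "    // best() returns the class with the highest probability\n",
    "    Classifications.Classification best = result.best();\n",
    "    System.out.println(best.getClassName() + \" (probability \" + best.getProbability() + \")\");\n",
    "}"
   ]
  },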
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "Now, you've successfully built a model, trained it, and run inference. Congratulations on finishing the [beginner tutorial series](https://github.com/awslabs/djl/tree/master/jupyter/tutorial). After this, you should read our other [examples](https://github.com/awslabs/djl/tree/master/examples) and [Jupyter notebooks](https://github.com/awslabs/djl/tree/master/jupyter) to learn more about DJL.\n",
    "\n",
    "You can find the complete source code for this tutorial in the [examples project](https://github.com/awslabs/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/ImageClassification.java)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "12.0.2+10"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
